/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 * contributors may be used to endorse or promote products derived from this
 * software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <array>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "paddle/extension.h"

/// Argument checks for paddle::Tensor inputs, built on PD_CHECK from
/// paddle/extension.h. `#x` stringifies the checked expression so that it
/// shows up in the error message.
#define CHECK_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a CUDA tensor")
#define CHECK_IS_HALF(x)                           \
  PD_CHECK(x.dtype() == paddle::DataType::FLOAT16, \
           #x " must be a float16 tensor")

/// Minimum GPU architecture; presumably compute capability 7.0 written as 70
/// (TODO confirm where this is enforced).
static constexpr uint32_t MIN_GPU_ARCH{70};

/// Element type alias used as the network precision: IEEE half (__half).
typedef __half network_precision_t;

/// Element-wise activation functions handled by warp_activation() and its
/// backward variants. Values are spelled out explicitly; they match the
/// implicit 0..6 numbering in declaration order.
enum class Activation {
  ReLU = 0,
  Exponential = 1,
  Sine = 2,
  Sigmoid = 3,
  Squareplus = 4,
  Softplus = 5,
  None = 6
};

// Single-precision mathematical constants.
static constexpr float PI{3.14159265358979323846f};     // pi
static constexpr float SQRT2{1.41421356237309504880f};  // sqrt(2)
// Sharpness factor for the Squareplus and Softplus activations, which are
// evaluated as g(K_ACT * x) / K_ACT.
static constexpr float K_ACT{10.0f};

/// Logistic sigmoid 1 / (1 + e^-x), callable from host and device code.
__host__ __device__ inline float logistic(const float x) {
  const float denominator = 1.0f + expf(-x);
  return 1.0f / denominator;
}

/// Inverse of logistic(): log(x / (1 - x)), evaluated as -log(1/x - 1).
///
/// x is clamped into the open interval (0, 1) so the result stays finite.
/// The upper bound must be representable in single precision. The former
/// bound `1.0f - 1e-9f` rounds to exactly 1.0f, which made logit(1.0f)
/// return +inf. `1.0f - 1e-7f` rounds to a float strictly below 1.0f.
__host__ __device__ inline float logit(const float x) {
  const float clamped = fminf(fmaxf(x, 1e-9f), 1.0f - 1e-7f);
  return -logf(1.0f / clamped - 1.0f);
}

/// Running total of bytes currently allocated through GPUMemory (incremented
/// in allocate_memory(), decremented in free_memory()).
inline std::atomic<size_t>& total_n_bytes_allocated() {
  // A function-local static avoids static-initialization-order problems, and
  // because the function is inline there is a single counter per program.
  static std::atomic<size_t> counter(0);
  return counter;
}

/// Checks the result of a cudaXXXXXX call and throws a std::runtime_error on
/// failure. The message contains the stringified call and the text returned
/// by cudaGetErrorString().
///
/// The temporary deliberately has an unlikely name. With a plain `result`, an
/// argument mentioning a caller variable called `result` (for example
/// CUDA_CHECK_THROW(result)) would bind to the macro's own, still
/// uninitialized variable instead.
#define CUDA_CHECK_THROW(x)                                      \
  do {                                                           \
    const cudaError_t cuda_check_throw_result_ = (x);            \
    if (cuda_check_throw_result_ != cudaSuccess)                 \
      throw std::runtime_error(                                  \
          std::string("CUDA Error: " #x " failed with error ") + \
          cudaGetErrorString(cuda_check_throw_result_));         \
  } while (0)

/// Checks the result of a cudaXXXXXX call and prints an error to std::cout on
/// failure; execution continues afterwards.
///
/// The temporary deliberately has an unlikely name. With a plain `result`, an
/// argument mentioning a caller variable called `result` (for example
/// CUDA_CHECK_PRINT(result)) would bind to the macro's own, still
/// uninitialized variable instead.
#define CUDA_CHECK_PRINT(x)                                        \
  do {                                                             \
    const cudaError_t cuda_check_print_result_ = (x);              \
    if (cuda_check_print_result_ != cudaSuccess)                   \
      std::cout << "CUDA Error: " #x " failed with error "         \
                << cudaGetErrorString(cuda_check_print_result_)    \
                << std::endl;                                      \
  } while (0)

/// Number of guard bytes GPUMemory places before and after each allocation
/// (0 disables the guards). The guards are filled with 0xff / 0xfe when the
/// buffer is allocated and verified by GPUMemory::check_guards(). The value
/// can be overridden at build time, e.g. -DDEBUG_GUARD_SIZE=16.
#ifndef DEBUG_GUARD_SIZE
#define DEBUG_GUARD_SIZE 0
#endif

// Forward declaration: bytes_to_string() is defined later in this header, but
// allocate_memory() below calls it (only when TCNN_VERBOSE_MEMORY_ALLOCS is
// defined) with a non-dependent argument. Two-phase name lookup therefore
// needs it declared before the class template definition.
inline std::string bytes_to_string(size_t bytes);

/// Managed memory on the Device
///
/// Typed wrapper around a cudaMalloc() buffer of size() elements of T.
///  - Objects created by the default or size constructor own their buffer and
///    release it in the destructor.
///  - The copy constructor (which is __host__ __device__) creates a shallow,
///    non-owning view of the same buffer. A view never frees the buffer and
///    refuses to resize().
///  - Move construction/assignment exchange buffer, size and ownership.
/// Transfer helpers use the blocking cudaMemcpy()/cudaMemset() calls and
/// report failures by throwing std::runtime_error.
template <class T>
class GPUMemory {
 private:
  T* m_data = nullptr;
  size_t m_size = 0;    // Number of elements
  bool m_owned = true;  // Whether this object frees m_data on destruction

  /// Throws std::runtime_error if a host vector of `vec_size` elements is too
  /// small to take part in a copy of `num_elements` elements.
  static void check_host_vector_size(const size_t vec_size,
                                     const size_t num_elements) {
    if (vec_size < num_elements) {
      throw std::runtime_error(
          std::string("Trying to copy ") + std::to_string(num_elements) +
          std::string(" elements, but vector size is only ") +
          std::to_string(vec_size));
    }
  }

  /// Throws std::runtime_error if `num_elements` exceeds the `available`
  /// number of elements of a device buffer.
  static void check_device_size(const size_t num_elements,
                                const size_t available) {
    if (num_elements > available) {
      throw std::runtime_error(
          std::string("Trying to copy ") + std::to_string(num_elements) +
          std::string(" elements, but memory size is only ") +
          std::to_string(available));
    }
  }

 public:
  GPUMemory() {}

  /// Move assignment: exchanges buffer, size and ownership with `other`.
  /// Whatever this object held before is released when `other` is destroyed.
  GPUMemory<T>& operator=(GPUMemory<T>&& other) noexcept {
    std::swap(m_data, other.m_data);
    std::swap(m_size, other.m_size);
    // Ownership must travel with the pointer. Without this, moving a
    // non-owning view into an owning object would free the viewed buffer a
    // second time, and moving an owner into a view would leak it.
    std::swap(m_owned, other.m_owned);
    return *this;
  }

  GPUMemory(GPUMemory<T>&& other) noexcept { *this = std::move(other); }

  /// Shallow copy: the new object aliases other's buffer but does not own it.
  __host__ __device__ GPUMemory(const GPUMemory<T>& other)
      : m_data{other.m_data}, m_size{other.m_size}, m_owned{false} {}

  /// When DEBUG_GUARD_SIZE > 0, reads back the guard bytes around the
  /// allocation and prints a message if they were overwritten. No-op
  /// otherwise.
  void check_guards() const {
#if DEBUG_GUARD_SIZE > 0
    if (!m_data) return;
    uint8_t buf[DEBUG_GUARD_SIZE];
    const uint8_t* rawptr = (const uint8_t*)m_data;
    cudaMemcpy(buf, rawptr - DEBUG_GUARD_SIZE, DEBUG_GUARD_SIZE,
               cudaMemcpyDeviceToHost);
    for (int i = 0; i < DEBUG_GUARD_SIZE; ++i)
      if (buf[i] != 0xff) {
        printf(
            "TRASH BEFORE BLOCK offset %d data %p, read 0x%02x expected "
            "0xff!\n",
            i, (void*)m_data, buf[i]);
        break;
      }
    cudaMemcpy(buf, rawptr + m_size * sizeof(T), DEBUG_GUARD_SIZE,
               cudaMemcpyDeviceToHost);
    for (int i = 0; i < DEBUG_GUARD_SIZE; ++i)
      if (buf[i] != 0xfe) {
        printf(
            "TRASH AFTER BLOCK offset %d data %p, read 0x%02x expected 0xfe!\n",
            i, (void*)m_data, buf[i]);
        break;
      }
#endif
  }

  /// Allocates `n_bytes` (plus guard bytes) with cudaMalloc() and adds
  /// `n_bytes` to total_n_bytes_allocated(). Neither frees a previous buffer
  /// nor updates m_size; resize() takes care of both.
  void allocate_memory(size_t n_bytes) {
    if (n_bytes == 0) {
      return;
    }

#ifdef TCNN_VERBOSE_MEMORY_ALLOCS
    std::cout << "GPUMemory: Allocating " << bytes_to_string(n_bytes) << "."
              << std::endl;
#endif

    uint8_t* rawptr = nullptr;
    CUDA_CHECK_THROW(cudaMalloc(&rawptr, n_bytes + DEBUG_GUARD_SIZE * 2));
#if DEBUG_GUARD_SIZE > 0
    CUDA_CHECK_THROW(cudaMemset(rawptr, 0xff, DEBUG_GUARD_SIZE));
    CUDA_CHECK_THROW(cudaMemset(rawptr + n_bytes + DEBUG_GUARD_SIZE, 0xfe,
                                DEBUG_GUARD_SIZE));
#endif
    if (rawptr) rawptr += DEBUG_GUARD_SIZE;
    m_data = (T*)(rawptr);
    total_n_bytes_allocated() += n_bytes;
  }

  /// Frees the buffer (if any) with cudaFree() and subtracts get_bytes() from
  /// total_n_bytes_allocated(). Leaves m_size unchanged.
  void free_memory() {
    if (!m_data) {
      return;
    }

    uint8_t* rawptr = (uint8_t*)m_data;
    if (rawptr) rawptr -= DEBUG_GUARD_SIZE;
    CUDA_CHECK_THROW(cudaFree(rawptr));

    total_n_bytes_allocated() -= get_bytes();

    m_data = nullptr;
  }

  /// Allocates memory for size items of type T
  GPUMemory(const size_t size) { resize(size); }

  /// Frees memory again (only if this object owns it). Errors are reported on
  /// stderr rather than thrown; "driver shutting down" errors are ignored.
  __host__ __device__ ~GPUMemory() {
#ifndef __CUDA_ARCH__
    if (!m_owned) {
      return;
    }

    try {
      if (m_data) {
        free_memory();
        m_size = 0;
      }
    } catch (const std::runtime_error& error) {
      // Don't need to report on memory-free problems when the driver is
      // shutting down.
      if (std::string{error.what()}.find("driver shutting down") ==
          std::string::npos) {
        fprintf(stderr, "Could not free memory: %s\n", error.what());
      }
    }
#endif
  }

  /** @name Resizing/enlargement
   *  @{
   */
  /// Resizes the array to the exact new size, even if it is already larger.
  /// Existing contents are not preserved (the old buffer is freed first).
  /// Throws for non-owned memory.
  void resize(const size_t size) {
    if (!m_owned) {
      throw std::runtime_error("Cannot resize non-owned memory.");
    }

    if (m_size == size) {
      return;
    }

    if (m_size) {
      try {
        free_memory();
      } catch (const std::runtime_error& error) {
        throw std::runtime_error(std::string("Could not free memory: ") +
                                 error.what());
      }
      // m_data is now nullptr. Keep m_size consistent with it in case the
      // allocation below throws; otherwise a later resize() to the old size
      // would be a no-op and leave a null buffer with a non-zero size.
      m_size = 0;
    }

    if (size > 0) {
      try {
        allocate_memory(size * sizeof(T));
      } catch (const std::runtime_error& error) {
        throw std::runtime_error(std::string("Could not allocate memory: ") +
                                 error.what());
      }
    }

    m_size = size;
  }

  /// Enlarges the array if its size is smaller
  void enlarge(const size_t size) {
    if (size > m_size) {
      resize(size);
    }
  }
  /** @} */

  /** @name Memset
   *  @{
   */
  /// Sets every byte of num_elements elements, starting at element `offset`,
  /// to `value` (byte-wise, as cudaMemset does).
  void memset(const int value, const size_t num_elements,
              const size_t offset = 0) {
    // Two comparisons so that num_elements + offset cannot wrap around.
    if (offset > m_size || num_elements > m_size - offset) {
      throw std::runtime_error(
          "Could not set memory: Number of elements larger than allocated "
          "memory");
    }

    try {
      CUDA_CHECK_THROW(
          cudaMemset(m_data + offset, value, num_elements * sizeof(T)));
    } catch (const std::runtime_error& error) {
      throw std::runtime_error(std::string("Could not set memory: ") +
                               error.what());
    }
  }

  /// Sets every byte of all elements to value
  void memset(const int value) { memset(value, m_size); }
  /** @} */

  /** @name Copy operations
   *  @{
   */
  /// Copies num_elements elements from the raw host pointer into the start of
  /// the device buffer. Fails if num_elements exceeds size().
  void copy_from_host(const T* host_data, const size_t num_elements) {
    check_device_size(num_elements, m_size);
    try {
      CUDA_CHECK_THROW(cudaMemcpy(data(), host_data, num_elements * sizeof(T),
                                  cudaMemcpyHostToDevice));
    } catch (const std::runtime_error& error) {
      throw std::runtime_error(std::string("Could not copy from host: ") +
                               error.what());
    }
  }

  /// Copy num_elements from the host vector
  void copy_from_host(const std::vector<T>& data, const size_t num_elements) {
    check_host_vector_size(data.size(), num_elements);
    copy_from_host(data.data(), num_elements);
  }

  /// Copies data from the raw host pointer to fill the entire array
  void copy_from_host(const T* data) { copy_from_host(data, m_size); }

  /// Copies num_elements of data from the raw host pointer after enlarging the
  /// array so that everything fits in
  void enlarge_and_copy_from_host(const T* data, const size_t num_elements) {
    enlarge(num_elements);
    copy_from_host(data, num_elements);
  }

  /// Copies num_elements from the host vector after enlarging the array so that
  /// everything fits in. Fails (before enlarging) if the vector is too small.
  void enlarge_and_copy_from_host(const std::vector<T>& data,
                                  const size_t num_elements) {
    check_host_vector_size(data.size(), num_elements);
    enlarge_and_copy_from_host(data.data(), num_elements);
  }

  /// Copies the entire host vector after enlarging the array so that everything
  /// fits in
  void enlarge_and_copy_from_host(const std::vector<T>& data) {
    enlarge_and_copy_from_host(data.data(), data.size());
  }

  /// Copies num_elements of data from the raw host pointer after resizing the
  /// array
  void resize_and_copy_from_host(const T* data, const size_t num_elements) {
    resize(num_elements);
    copy_from_host(data, num_elements);
  }

  /// Copies num_elements from the host vector after resizing the array. Fails
  /// (before resizing) if the vector is too small.
  void resize_and_copy_from_host(const std::vector<T>& data,
                                 const size_t num_elements) {
    check_host_vector_size(data.size(), num_elements);
    resize_and_copy_from_host(data.data(), num_elements);
  }

  /// Copies the entire host vector after resizing the array
  void resize_and_copy_from_host(const std::vector<T>& data) {
    resize_and_copy_from_host(data.data(), data.size());
  }

  /// Fills the whole device array from the host vector. Fails if the vector
  /// holds fewer than size() elements.
  void copy_from_host(const std::vector<T>& data) {
    check_host_vector_size(data.size(), m_size);
    copy_from_host(data.data(), m_size);
  }

  /// Copies the first num_elements elements from the device to the raw host
  /// pointer. Fails if num_elements exceeds size().
  void copy_to_host(T* host_data, const size_t num_elements) const {
    check_device_size(num_elements, m_size);
    try {
      CUDA_CHECK_THROW(cudaMemcpy(host_data, data(), num_elements * sizeof(T),
                                  cudaMemcpyDeviceToHost));
    } catch (const std::runtime_error& error) {
      throw std::runtime_error(std::string("Could not copy to host: ") +
                               error.what());
    }
  }

  /// Copies num_elements from the device to a vector on the host
  void copy_to_host(std::vector<T>& data, const size_t num_elements) const {
    check_host_vector_size(data.size(), num_elements);
    copy_to_host(data.data(), num_elements);
  }

  /// Copies all size() elements from the device to a raw pointer on the host
  void copy_to_host(T* data) const { copy_to_host(data, m_size); }

  /// Copies all elements from the device to a vector on the host
  void copy_to_host(std::vector<T>& data) const {
    check_host_vector_size(data.size(), m_size);
    copy_to_host(data.data(), m_size);
  }

  /// Copies data from another device array to this one, automatically resizing
  /// it
  void copy_from_device(const GPUMemory<T>& other) {
    if (m_size != other.m_size) {
      resize(other.m_size);
    }

    try {
      CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, m_size * sizeof(T),
                                  cudaMemcpyDeviceToDevice));
    } catch (const std::runtime_error& error) {
      throw std::runtime_error(std::string("Could not copy from device: ") +
                               error.what());
    }
  }

  /// Copies size elements from another device array to this one, enlarging
  /// it if needed. Fails if `other` holds fewer than `size` elements.
  void copy_from_device(const GPUMemory<T>& other, const size_t size) {
    check_device_size(size, other.m_size);
    if (m_size < size) {
      resize(size);
    }

    try {
      CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, size * sizeof(T),
                                  cudaMemcpyDeviceToDevice));
    } catch (const std::runtime_error& error) {
      throw std::runtime_error(std::string("Could not copy from device: ") +
                               error.what());
    }
  }

  // Created an (owned) copy of the data
  GPUMemory<T> copy() const {
    GPUMemory<T> result{m_size};
    result.copy_from_device(*this);
    return result;
  }

  /// Returns the raw device pointer (after checking the debug guards).
  T* data() const {
    check_guards();
    return m_data;
  }

  /// Unchecked element access. m_data points to cudaMalloc'd memory, so
  /// dereferencing it is presumably only valid in device code (TODO confirm
  /// no unified-memory use).
  __host__ __device__ T& operator[](size_t idx) const {
#ifdef DEBUG_BUFFER_OVERRUN
    if (idx >= m_size) {
      printf("WARNING: buffer overrun of %p at idx %llu\n", (void*)m_data,
             (unsigned long long)idx);
    }
#endif
    return m_data[idx];
  }

  __host__ __device__ T& operator[](uint32_t idx) const {
#ifdef DEBUG_BUFFER_OVERRUN
    if (idx >= m_size) {
      printf("WARNING: buffer overrun of %p at idx %u\n", (void*)m_data, idx);
    }
#endif
    return m_data[idx];
  }

  size_t get_num_elements() const { return m_size; }

  size_t size() const { return get_num_elements(); }

  size_t get_bytes() const { return m_size * sizeof(T); }

  size_t bytes() const { return get_bytes(); }
};

/// Formats a byte count as a short human-readable string in binary
/// (1024-based) units, printing at most 3 significant digits,
/// e.g. 1536 -> "1.5 KB".
inline std::string bytes_to_string(size_t bytes) {
  const std::array<const char*, 7> units = {
      {"B", "KB", "MB", "GB", "TB", "PB", "EB"}};

  double scaled = static_cast<double>(bytes);
  std::size_t unit = 0;
  while (unit < units.size() && scaled >= 1024.0) {
    scaled /= 1024.0;
    ++unit;
  }

  std::ostringstream out;
  out.precision(3);
  out << scaled << ' ' << units[unit];
  return out.str();
}

/// Applies `activation` element-wise to `frag` and writes the outputs to
/// `result`.
///
/// `fragment_t` must expose `num_elements` and an indexable `x` member (as
/// used below); `T` is the arithmetic type the outputs are cast to.
/// Squareplus and Softplus are sharpened by K_ACT: f(x) = g(K_ACT * x) /
/// K_ACT. Unrecognized activations return without writing `result`.
template <typename T, typename fragment_t>
__host__ __device__ void warp_activation(Activation activation,
                                         const fragment_t& frag,
                                         fragment_t& result) {
  switch (activation) {
    case Activation::ReLU:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f);
      }
      return;
    case Activation::Exponential:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = (T)(expf((float)frag.x[t]));
      }
      return;
    case Activation::Sine:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = (T)(sinf((float)frag.x[t]));
      }
      return;
    case Activation::Sigmoid:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = (T)(logistic((float)frag.x[t]));
      }
      return;
    case Activation::Squareplus:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        float x = (float)frag.x[t] * K_ACT;
        result.x[t] = (T)(0.5f * (x + sqrtf(x * x + 4)) / K_ACT);
      }
      return;
    case Activation::Softplus:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        // Numerically stable softplus:
        //   log(1 + e^z) = max(z, 0) + log1p(e^-|z|).
        // The naive logf(expf(z) + 1.0f) becomes +inf as soon as expf(z)
        // overflows (z > ~88, i.e. inputs above ~88 / K_ACT).
        const float z = (float)frag.x[t] * K_ACT;
        result.x[t] =
            (T)((fmaxf(z, 0.0f) + log1pf(expf(-fabsf(z)))) / K_ACT);
      }
      return;
    case Activation::None:
      result = frag;
      return;
    default:
      // Unsupported activation
      // assert(false); // Commented out due to isolated strange side-effects on
      // Windows
      return;
  }
}

/// Convenience overload of warp_activation() that returns the activated
/// fragment by value instead of writing through an output parameter.
/// For unrecognized activations the returned fragment is left as
/// default-initialized.
template <typename T, typename fragment_t>
__host__ __device__ fragment_t warp_activation(Activation activation,
                                               const fragment_t& frag) {
  fragment_t activated;
  warp_activation<T, fragment_t>(activation, frag, activated);
  return activated;
}

/// Backward pass of warp_activation() computed from the activation *inputs*.
///
/// `frag` holds the incoming gradient w.r.t. the activation output, and
/// `forward_frag_in` holds the pre-activation values that were fed to the
/// forward pass. `result` receives the gradient w.r.t. those inputs,
/// frag * f'(input). Unrecognized activations return without writing
/// `result`.
template <typename T, typename fragment_t, typename forward_fragment_t>
__host__ __device__ void warp_activation_backward_in(
    Activation activation, const fragment_t& frag,
    const forward_fragment_t& forward_frag_in, fragment_t& result) {
  switch (activation) {
    case Activation::ReLU:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f);
      }
      return;
    case Activation::Exponential:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = frag.x[t] * (T)(expf((float)forward_frag_in.x[t]));
      }
      return;
    case Activation::Sine:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        result.x[t] = frag.x[t] * (T)(cosf((float)forward_frag_in.x[t]));
      }
      return;
    case Activation::Sigmoid:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        float x = logistic((float)forward_frag_in.x[t]);
        result.x[t] = frag.x[t] * (T)(x * (1.0f - x));
      }
      return;
    case Activation::Squareplus:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        // y = squareplus(K_ACT * x); its derivative is y^2 / (y^2 + 1).
        float x = (float)forward_frag_in.x[t] * K_ACT;
        float y = 0.5f * (x + sqrtf(x * x + 4));
        result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1));
      }
      return;
    case Activation::Softplus:
#pragma unroll
      for (int t = 0; t < result.num_elements; t++) {
        // d/dx [softplus(K_ACT * x) / K_ACT] = sigmoid(K_ACT * x). This must
        // be evaluated at the forward *input*, not at the incoming gradient
        // `frag`. logistic() also avoids the inf/inf = NaN that
        // expf(z) / (expf(z) + 1) produces for large z.
        result.x[t] = frag.x[t] *
                      (T)(logistic((float)forward_frag_in.x[t] * K_ACT));
      }
      return;
    case Activation::None:
      result = frag;
      return;
    default:
      // Unsupported activation
      // assert(false); // Commented out due to isolated strange side-effects on
      // Windows
      return;
  }
}

/// By-value convenience overload of warp_activation_backward_in(). It returns
/// the gradient w.r.t. the activation inputs instead of writing through an
/// output parameter. For unrecognized activations the returned fragment is
/// left as default-initialized.
template <typename T, typename fragment_t, typename forward_fragment_t>
__host__ __device__ fragment_t
warp_activation_backward_in(Activation activation, const fragment_t& frag,
                            const forward_fragment_t& forward_frag_in) {
  fragment_t grad;
  warp_activation_backward_in<T, fragment_t, forward_fragment_t>(
      activation, frag, forward_frag_in, grad);
  return grad;
}

/// Backward pass of warp_activation() computed from the activation *outputs*.
///
/// `frag` holds the incoming gradient w.r.t. the activation output, and
/// `forward_frag` holds the post-activation values produced by the forward
/// pass. `result` receives the gradient w.r.t. the activation input. The
/// derivative of Sine cannot be recovered from its output alone, so Sine
/// (like any unrecognized activation) returns without writing `result`.
template <typename T, typename fragment_t, typename forward_fragment_t>
__host__ __device__ void warp_activation_backward(
    Activation activation, const fragment_t& frag,
    const forward_fragment_t& forward_frag, fragment_t& result) {
  // Identity activation: the gradient passes through unchanged.
  if (activation == Activation::None) {
    result = frag;
    return;
  }

  switch (activation) {
    case Activation::ReLU:
      // Gradient flows only where the stored output is positive.
#pragma unroll
      for (int i = 0; i < result.num_elements; ++i) {
        const auto out = forward_frag.x[i];
        result.x[i] = frag.x[i] * (T)(out > (T)0.0f);
      }
      break;
    case Activation::Exponential:
      // d/dx e^x = e^x, which is exactly the stored output.
#pragma unroll
      for (int i = 0; i < result.num_elements; ++i) {
        result.x[i] = frag.x[i] * forward_frag.x[i];
      }
      break;
    case Activation::Sigmoid:
      // sigmoid'(x) = s * (1 - s), with s the stored output.
#pragma unroll
      for (int i = 0; i < result.num_elements; ++i) {
        const auto s = forward_frag.x[i];
        result.x[i] = frag.x[i] * (T)(s * ((T)1.0f - s));
      }
      break;
    case Activation::Squareplus:
      // Undo the 1/K_ACT output scaling to get y = squareplus(K_ACT * x);
      // the derivative is y^2 / (y^2 + 1).
#pragma unroll
      for (int i = 0; i < result.num_elements; ++i) {
        const float y = (float)forward_frag.x[i] * K_ACT;
        result.x[i] = frag.x[i] * (T)(y * y / (y * y + 1));
      }
      break;
    case Activation::Softplus:
      // z = softplus(K_ACT * x); the derivative sigmoid(K_ACT * x) equals
      // 1 - e^-z.
#pragma unroll
      for (int i = 0; i < result.num_elements; ++i) {
        const float z = (float)forward_frag.x[i] * K_ACT;
        result.x[i] = frag.x[i] * (T)(1.0f - expf(-z));
      }
      break;
    case Activation::Sine:
      // Sine requires stored pre-activations, which we don't have; only the
      // post-activations are available here. assert(false) is intentionally
      // omitted due to isolated strange side-effects on Windows.
    default:
      // Unsupported activation: leave `result` untouched.
      break;
  }
}

/// By-value convenience overload of warp_activation_backward(). It returns
/// the gradient w.r.t. the activation input, computed from the stored
/// post-activations. For Sine and unrecognized activations the returned
/// fragment is left as default-initialized.
template <typename T, typename fragment_t, typename forward_fragment_t>
__host__ __device__ fragment_t
warp_activation_backward(Activation activation, const fragment_t& frag,
                         const forward_fragment_t& forward_frag) {
  fragment_t grad;
  warp_activation_backward<T, fragment_t, forward_fragment_t>(
      activation, frag, forward_frag, grad);
  return grad;
}
